from torch.quantization import fuse_modules 
from utils.model import YoloModel, Conv2d 
import torch 



class YoloModelFuse(YoloModel):
    """``YoloModel`` variant with layer fusion and in-graph box decoding.

    ``forward`` runs the backbone and the three detection heads inherited
    from ``YoloModel`` and decodes the raw head outputs into a single
    ``[batch, 10647, 85]`` tensor (3 anchors x (52*52 + 26*26 + 13*13) cells).
    The grid sizes 52/26/13 are hard-coded, i.e. they match a 416x416 input
    with strides 8/16/32.
    """

    def fuse_model(self):
        """Fuse sub-modules ``layers.0`` and ``layers.1`` of every ``Conv2d`` block.

        The fusion is done in place on this model.
        NOTE(review): presumably these are a convolution followed by batch
        norm, in which case the model should be in eval mode before this is
        called -- confirm against ``utils.model.Conv2d``.
        """
        for m in self.modules():
            # Exact type match on purpose: subclasses of Conv2d are skipped,
            # as their ``layers`` layout is not known here.
            if type(m) is Conv2d:
                fuse_modules(m, ["layers.0", "layers.1"], inplace=True)

    def _decode(self, raw: torch.Tensor, anchors: torch.Tensor, size: int, stride: int) -> torch.Tensor:
        """Decode one head output into boxes of shape ``[batch, 3*size*size, 85]``.

        Args:
            raw: Head output holding ``batch * 3 * 85 * size * size`` values.
            anchors: Anchor sizes shaped ``[1, 3, 1, 1, 2]``.
            size: Side length of the square prediction grid.
            stride: Factor that scales grid coordinates to the xy output.

        Returns:
            Tensor whose last axis is ``[x, y, w, h, sigmoid(rest...)]``.

        The writes below are in place; when ``reshape`` returns a view they
        also overwrite the storage of ``raw``.
        """
        # -1 lets reshape infer the batch size instead of assuming batch == 1.
        pred = raw.reshape([-1, 3, 85, size, size]).permute(0, 1, 3, 4, 2)

        # grid[0, 0, i, j] == (j, i): column index first, then row index.
        # Built directly on the target device to avoid a host-to-device copy.
        yv, xv = torch.meshgrid(torch.arange(size, device=raw.device), torch.arange(size, device=raw.device))
        grid = torch.stack([xv, yv], 2).reshape([1, 1, size, size, 2]).float()

        pred[..., 0:2] = (pred[..., 0:2].sigmoid() + grid) * stride  # xy
        pred[..., 2:4] = torch.exp(pred[..., 2:4]) * anchors  # wh
        pred[..., 4:] = pred[..., 4:].sigmoid()
        return pred.reshape([-1, 3 * size * size, 85])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network and return decoded predictions for all three scales.

        Args:
            x: Input batch; the hard-coded grid sizes require the heads to
                produce 52x52, 26x26 and 13x13 maps (e.g. a 416x416 input).

        Returns:
            Tensor of shape ``[batch, 10647, 85]``: the 52x52 predictions
            first, then 26x26, then 13x13.
        """
        h0 = self.base0(x)
        h1 = self.base1(h0)
        h2 = self.base2(h1)

        # Each head except the last also returns a feature map that is
        # concatenated (channel axis) onto the next, finer backbone stage.
        y2, cat1 = self.yolo2(h2)
        h1 = torch.cat([h1, cat1], dim=1)
        y1, cat0 = self.yolo1(h1)
        h0 = torch.cat([h0, cat0], dim=1)
        y0 = self.yolo0(h0)

        anch0 = torch.tensor([[10,13], [16,30], [33,23]], dtype=torch.float32, device=x.device).view(1, -1, 1, 1, 2)
        anch1 = torch.tensor([[30,61], [62,45], [59,119]], dtype=torch.float32, device=x.device).view(1, -1, 1, 1, 2)
        anch2 = torch.tensor([[116,90], [156,198], [373,326]], dtype=torch.float32, device=x.device).view(1, -1, 1, 1, 2)

        y0 = self._decode(y0, anch0, 52, 8)
        y1 = self._decode(y1, anch1, 26, 16)
        y2 = self._decode(y2, anch2, 13, 32)
        return torch.cat([y0, y1, y2], dim=1)

# Export pipeline: restore the trained weights on the CPU, fuse the Conv2d
# blocks, then serialize the model twice -- as TorchScript and as ONNX.
device = torch.device("cpu")
model = YoloModelFuse()
model.load_state_dict(
    torch.load("ckpt/model.pt", map_location=device)
)
# Switch to eval mode before the layers are fused.
model.eval()
model.fuse_model()

# TorchScript artifact (compiled with torch.jit.script, not traced).
torch.jit.save(torch.jit.script(model), "ckpt/model.jit")

# ONNX artifact, exported with a random [1, 3, 416, 416] example input.
input_names = ["image"]
output_names = ["output"]
dummy_input = torch.randn([1, 3, 416, 416])
torch.onnx.export(
    model,
    dummy_input,
    "ckpt/model.onnx",
    verbose=True,
    input_names=input_names,
    output_names=output_names,
    opset_version=11,
)
